import argparse
import base64
import glob
import io
import json
import os

from openai import OpenAI
from PIL import Image


# Directory (a relative path, resolved against the current working directory)
# that create_temp() creates for temporary working files.
TEMP_PATH = 'temp'


def create_temp():
    """
    Creates the temporary working directory if it does not already exist.

    The directory path is defined by the module-level TEMP_PATH constant.
    Missing parent directories are created as well, and calling this
    function repeatedly is safe.

    Parameters:
    None

    Returns:
    None

    Raises:
    OSError: If the directory cannot be created (for example insufficient
    permissions, or a non-directory file already occupies TEMP_PATH).
    """
    # exist_ok=True avoids the check-then-create race of os.path.exists()
    # followed by os.makedirs(): if another process created the directory in
    # between, makedirs would otherwise raise FileExistsError.
    os.makedirs(TEMP_PATH, exist_ok=True)


def remove_temp_file(path):
    """
    Delete the temporary file located at ``path``.

    Parameters:
    path (str): Filesystem path of the temporary file to delete.

    Returns:
    None

    Raises:
    FileNotFoundError: If nothing exists at ``path``.
    PermissionError: If the process is not allowed to delete the file.
    """
    # os.unlink is the POSIX-named alias of os.remove; the two are identical.
    os.unlink(path)
    

def resize_and_convert_to_base64(image_path, max_size=(768, 768), image_extension='PNG'):
    """
    Resizes an image and returns it base64-encoded in the requested format.

    The image is opened and resized in place (Pillow's ``thumbnail``) to fit
    within ``max_size`` while keeping its aspect ratio. It is then re-encoded
    in the requested format and returned as a base64 string. Encoding happens
    in an in-memory buffer, so no temporary file is written.

    Parameters:
    image_path (str): The path to the image file to be resized and converted.
    max_size (tuple, optional): The maximum (width, height) of the result. Defaults to (768, 768).
    image_extension (str, optional): Pillow format name to encode as (case-insensitive);
        'JPG' is accepted as an alias for 'JPEG'. Defaults to 'PNG'.

    Returns:
    str: The base64-encoded representation of the resized and converted image.

    Raises:
    FileNotFoundError: If the image file does not exist at the specified path.
    OSError: If the image cannot be read or encoded.
    Pillow also raises an error if ``image_extension`` is not a known format.
    """
    image_format = image_extension.upper()
    # Pillow registers JPEG under the format name 'JPEG'; 'JPG' is only a file
    # extension, and passing it to Image.save() finds no writer.
    if image_format == 'JPG':
        image_format = 'JPEG'

    with Image.open(image_path) as img:
        # Calculate the new size preserving the aspect ratio (in place).
        img.thumbnail(max_size, Image.LANCZOS)

        # JPEG cannot store alpha or palette images (e.g. RGBA PNGs), so
        # convert those to RGB instead of letting save() fail.
        if image_format == 'JPEG' and img.mode not in ('RGB', 'L', 'CMYK'):
            img = img.convert('RGB')

        # Encode in memory. The old temp-file approach derived the file name
        # via split('/')[1] (wrong for most paths) and deleted the file while
        # it was still open (fails on Windows).
        buffer = io.BytesIO()
        img.save(buffer, format=image_format)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')
        

def create_metadata_jsonl(image_paths, texts, data_focus='statue', output_file='data/metadata.jsonl'):
    """
    Creates a JSONL file containing metadata for a dataset of images.

    Each output line is a JSON object with two keys:
    - 'file_name': the base name of the image path (directories are dropped).
    - 'prompt': 'a photo of CUS <data_focus>, ' followed by the corresponding caption.
    The file at ``output_file`` is overwritten.

    Parameters:
    image_paths (List[str]): A list of paths to the images.
    texts (List[str]): A list of corresponding captions for the images.
    data_focus (str, optional): The main object in the dataset. Defaults to 'statue';
        None is treated the same as the default.
    output_file (str, optional): The path to the output JSONL file. Defaults to 'data/metadata.jsonl'.

    Returns:
    None

    Raises:
    ValueError: If the lengths of image_paths and texts are not equal.
    """
    if len(image_paths) != len(texts):
        raise ValueError("The length of image paths and texts must be the same.")

    # An unset CLI option arrives here as None; fall back to the default
    # instead of failing with a TypeError on string concatenation.
    if data_focus is None:
        data_focus = 'statue'
    caption_prefix = 'a photo of CUS ' + data_focus + ', '

    with open(output_file, 'w', encoding='utf-8') as f:
        for image_path, text in zip(image_paths, texts):
            record = {
                # os.path.basename also handles the platform's own separator,
                # unlike a hard-coded split('/').
                "file_name": os.path.basename(image_path),
                "prompt": caption_prefix + text,
            }
            # One JSON object per line (JSON Lines).
            f.write(json.dumps(record) + '\n')


def annotate_image(image_path, image_extension, openai_secret_key=None):
    """
    Annotates an image using OpenAI's "gpt-4o" chat model.

    The API key is resolved first. The explicit ``openai_secret_key`` argument
    wins; otherwise the 'OPENAI_SECRET_KEY' environment variable is used. The
    image is then resized and base64-encoded, and sent as a data URL together
    with a captioning instruction. The text of the first returned choice is
    the annotation.

    Parameters:
    image_path (str): The path to the image file to be annotated.
    image_extension (str): Format to encode the image as (e.g. 'PNG', 'JPEG', 'JPG');
        it also determines the MIME type of the data URL.
    openai_secret_key (str, optional): The OpenAI secret key. If not provided, the function
        will look for it in the environment variable 'OPENAI_SECRET_KEY'.

    Returns:
    str: The annotation generated by the model.

    Raises:
    ValueError: If the OpenAI secret key is not provided and not found in the environment variable.
    """
    # Resolve the key before any image work so a missing key fails fast.
    # The previous logic raised even when the environment variable was set.
    secret_key = openai_secret_key or os.environ.get("OPENAI_SECRET_KEY")
    if not secret_key:
        raise ValueError("No OpenAI secret key found. Please set the OPENAI_SECRET_KEY environment variable or pass it as an argument.")

    image_base64 = resize_and_convert_to_base64(image_path, image_extension=image_extension)
    # The data URL's MIME type must describe the bytes actually encoded above,
    # rather than always claiming JPEG.
    image_format = image_extension.upper()
    mime_subtype = 'jpeg' if image_format in ('JPG', 'JPEG') else image_format.lower()

    description = """Directly describe with brevity and as brief as possible the provided image.
    without any introductory phrase like 'This image shows', 'In the scene',
    'This image depicts' or similar phrases. Just start describing the scene please. Do not end the caption with a '.'
    Good examples: a cat on a windowsill, a photo of smiling cactus in an office, a man and baby sitting by a window, a photo of wheel on a car.
    """

    client = OpenAI(
        api_key=secret_key,
    )
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": description},
                    {"type": "image_url", "image_url": {"url": f"data:image/{mime_subtype};base64,{image_base64}"}}
                ]
            }
        ],
        model="gpt-4o",
    )

    return response.choices[0].message.content


def main():
    """
    Command-line entry point: captions every image with the given extension
    directly inside --images_path and writes '<images_path>/metadata.jsonl'.
    """
    parser = argparse.ArgumentParser(description='Annotate data')

    parser.add_argument('--images_path', type=str, required=True, help='Path to the images.')
    parser.add_argument('--image_extension', type=str, required=True, help='Extension of the images.')
    parser.add_argument('--secret_key', type=str, help='Secret key as a string.')
    # Default matches create_metadata_jsonl's data_focus default. Without it,
    # an omitted --focus would be passed on as None.
    parser.add_argument('--focus', type=str, default='statue',
                        help='Type the main object in the dataset. Like statue or kid')

    args = parser.parse_args()

    create_temp()

    # Accept the extension with or without a leading dot ('png' or '.png').
    extension = args.image_extension.lstrip('.')
    # Pillow's format name for JPEG is 'JPEG'; 'JPG' is only a file extension.
    image_format = extension.upper()
    if image_format == 'JPG':
        image_format = 'JPEG'

    # glob returns files in arbitrary order; sort for a reproducible
    # processing order and metadata file.
    image_paths = sorted(glob.glob(os.path.join(args.images_path, '*.' + extension)))
    annotations = [
        annotate_image(path, image_format, openai_secret_key=args.secret_key)
        for path in image_paths
    ]

    create_metadata_jsonl(image_paths, annotations, data_focus=args.focus,
                          output_file=os.path.join(args.images_path, 'metadata.jsonl'))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
